# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime, timedelta

import torch
import torch.distributed as dist
from torch.nn import functional as F
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                     get_model_state_dict,
                                                     get_state_dict, set_state_dict)
from mmengine import mkdir_or_exist
from mmengine.runner import set_random_seed
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env
import itertools

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.import_utils import is_flash_attn_2_available

from xtuner._lite import (AutoTokenizer, get_device, get_logger,
                          get_torch_device_module)
from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_hf_code, LoadWoInit,
                                     packed_sequence, varlen_attn_is_available, profile_time_and_memory)
from xtuner._lite.datasets import (DATASET_CLS_MAP, OPENAI_CONVERT_MAP,
                                   SoftPackDataset, HardPackDataset, load_datasets)
from xtuner._lite.parallel import (ParallelSampler, get_dp_mesh, get_fsdp_mesh,
                                   get_sp_mesh, get_tp_mesh, get_world_mesh, get_same_data_mesh,
                                   pad_for_sequence_parallel, setup_parallel,
                                   reduce_sequence_parallel_loss,
                                   split_for_sequence_parallel)
from xtuner._lite.parallel.megatron import megatron_parallelize
from xtuner._lite.parallel.fsdp import clip_grad_norm_

from internlm.utils.common import assert_current_device_empty
from internlm.utils.execution_time import execution_time_collecter as etc
from torch.utils.tensorboard import SummaryWriter
import threading
import queue

assert_current_device_empty()
with etc.collect_execute_time("import_time"):
    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc
    from internlm.data.build_dataloader import (
        build_train_loader_with_data_type,
    )
    from internlm.data.utils import get_lang_subset_types
    from internlm.train.pipeline import load_new_batch_with_train_state
    from internlm.data.train_state import get_train_state
    from internlm.initialize import initialize_distributed_env
    from internlm.utils.common import (
        BatchSkipper,
        catch_error_node,
        enable_pytorch_expandable_segments,
        get_current_device,
        get_gpu_id,
        get_megatron_flops,
        launch_time,
        switch_topology_aware_rank_scheduling,
    )

# Module-wide logger from xtuner._lite; `sft` attaches stderr and per-rank
# file sinks to it via `logger.add(...)`.
logger = get_logger()

# Device type (used with `torch.amp.autocast` / `.to(DEVICE)`) and the matching
# torch device module (used for `empty_cache`, `is_bf16_supported`, ...).
# Presumably "cuda" / `torch.cuda` on GPU machines — confirm for other backends.
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()

# Names of the conversation formats supported by xtuner's OpenAI-style
# converters (not referenced by the visible code in this file).
SUPPORT_DATA_FORMATS = OPENAI_CONVERT_MAP.keys()


def record_tensorboard(log_dir, queue: queue.Queue):
    """Drain ``(tag, value, step)`` items from ``queue`` into TensorBoard.

    Runs on a background thread until it receives the ``None`` sentinel, then
    closes (and thereby flushes) the underlying ``SummaryWriter``.

    Args:
        log_dir: directory for the TensorBoard event files.
        queue: producer/consumer queue filled by ``SummaryWriterWrapper``.
    """
    writer = SummaryWriter(log_dir=log_dir)
    num_written = 0
    try:
        while True:
            # Block until work arrives instead of polling `empty()` + sleep.
            item = queue.get()
            if item is None:  # shutdown sentinel sent by `close()`
                break
            tag, value, step = item
            try:
                writer.add_scalar(tag, value, step)
            except Exception as err:  # noqa: BLE001
                # If this thread died, the bounded queue would fill up and
                # the training loop would block forever in `put`.
                print(f"[tensorboard] failed to write {tag!r} at step {step}: {err!r}")
                continue
            num_written += 1
            if num_written % 1000 == 0:
                print(f"qsize {queue.qsize()}")
    finally:
        writer.close()


class SummaryWriterWrapper(SummaryWriter):
    """TensorBoard logger that only writes from rank 0, on a background thread.

    NOTE: ``SummaryWriter.__init__`` is intentionally NOT called, so no event
    writer exists on this object; only the methods defined here are usable.
    Scalars are handed to ``record_tensorboard`` through a bounded queue (puts
    block once 5000 items are pending).

    Args:
        log_dir: TensorBoard log directory (used on rank 0 only).
        comment, purge_step, max_queue, flush_secs, filename_suffix: accepted
            for ``SummaryWriter`` signature compatibility; ignored.
        dataset_types: names of the dataset types indexed by ``type_ids``; an
            extra ``"undefined"`` bucket is appended for ``type_ids == -1``.
    """

    def __init__(
        self,
        log_dir=None,
        comment="",
        purge_step=None,
        max_queue=10,
        flush_secs=120,
        filename_suffix="",
        dataset_types=None,
    ):
        import atexit

        if dist.get_rank() == 0:
            self.queue = queue.Queue(maxsize=5000)
            # Daemon thread so it can never keep the process alive; `close`
            # (also registered with atexit, which runs before daemon threads
            # are torn down) drains the queue and closes the writer.
            self.thread = threading.Thread(
                target=record_tensorboard,
                args=(log_dir, self.queue),
                daemon=True,
            )
            self.thread.start()
            atexit.register(self.close)
        else:
            self.queue = None
            self.thread = None
        self.dataset_types = list(dataset_types or []) + ["undefined"]

    def close(self):
        """Stop the writer thread after it has written all pending scalars.

        Idempotent; a no-op on ranks other than 0.
        """
        thread = self.thread
        if thread is None:
            return
        self.thread = None
        self.queue.put(None)
        thread.join()

    def add_scalar(
        self,
        tag,
        scalar_value,
        global_step=None,
        walltime=None,
        new_style=False,
        double_precision=False,
        reduce_op=None,
    ):
        """Queue one scalar for rank 0, optionally all-reducing it first.

        When ``reduce_op`` is given this is a collective: every rank must call
        it in the same order. ``walltime``/``new_style``/``double_precision``
        are accepted for signature compatibility and ignored.
        """
        if reduce_op is not None:
            scalar_value = torch.tensor(scalar_value).cuda()
            dist.all_reduce(scalar_value, op=reduce_op)
            scalar_value = scalar_value.item()
        # `thread` is None on non-zero ranks and after `close()`; writing to
        # the queue then would only block once it fills up.
        if dist.get_rank() == 0 and self.thread is not None:
            self.queue.put((tag, scalar_value, global_step))

    def add_train_dynamics(self, loss, unreduced_loss, batch, steps):
        """Log the step loss and the mean token loss for each dataset type.

        Args:
            loss: scalar loss of the whole step (``train/loss``).
            unreduced_loss: list of per-token loss tensors, one per
                micro-batch, computed with ``reduction='none'``.
            batch: ``(inputs, labels)``. ``inputs["type_ids"]`` and ``labels``
                are assumed to be (B, T) and aligned with the losses once the
                last position is dropped — TODO confirm for new callers.
            steps: global step used as the TensorBoard x-axis.
        """
        self.add_scalar("train/loss", loss, global_step=steps)
        if self.thread is None:
            # Nothing would be written on this rank; skip the GPU work.
            return

        num_types = len(self.dataset_types)
        token_loss = torch.cat(unreduced_loss, dim=0).flatten()  # (B*(T-1),)
        device = token_loss.device
        type_ids = batch[0]["type_ids"].to(device)[:, :-1].flatten().long()
        # Out-of-place remap of untyped tokens (-1) to the "undefined" bucket,
        # so the caller's batch is left untouched.
        type_ids = type_ids.masked_fill(type_ids == -1, num_types - 1)
        labels = batch[1].to(device)[:, :-1].flatten()
        if not (token_loss.numel() == type_ids.numel() == labels.numel()):
            # e.g. sequence parallelism: `unreduced_loss` then only holds this
            # rank's shard and cannot be aligned with the full type_ids.
            return

        # Ignored positions (label < 0) have zero loss and are not part of
        # the training-loss denominator; exclude them from the mean as well.
        valid = labels >= 0
        type_ids = type_ids[valid]
        token_loss = token_loss[valid]

        count = torch.bincount(type_ids, minlength=num_types)
        loss_sum = torch.zeros(num_types, device=device, dtype=token_loss.dtype)
        loss_sum.scatter_add_(0, type_ids, token_loss)
        mean_loss = (loss_sum / count.clamp(min=1)).tolist()
        for name, value, n in zip(self.dataset_types, mean_loss, count.tolist()):
            if n > 0:  # don't plot a fake 0 loss for types absent this step
                self.add_scalar(f"train/loss/{name}", value, global_step=steps)

    def add_optimize_info(self, grad_norm, train_state, steps):
        """Log the gradient norm and the running count of skipped batches."""
        self.add_scalar("optimize/grad_norm", grad_norm, global_step=steps)
        self.add_scalar(
            "optimize/inf_nan_skip_batches",
            train_state.inf_nan_skip_batches,
            global_step=steps,
        )

    def add_data_infos(self, batch, train_state, step):
        """Log rank 0's token count per dataset type and per-file epochs.

        Only rank 0's local batch is counted (no reduction across ranks).
        """
        if dist.get_rank() != 0:
            return
        num_types = len(self.dataset_types)
        type_ids = batch[0]["type_ids"]  # B L
        # masked_fill is out-of-place: the caller's batch is not modified.
        type_ids = type_ids.masked_fill(type_ids == -1, num_types - 1)
        counts = torch.bincount(type_ids.flatten(), minlength=num_types).tolist()
        for name, n in zip(self.dataset_types, counts):
            self.add_scalar("data_tokens/" + name, n, step)

        # epochs for subsets (rank 0's view of the data state only)
        used_epochs = train_state.data_state_dict["used_epochs"]
        for file_name, e in used_epochs.items():
            self.add_scalar(
                f"data_subset_epochs_rank0/{file_name}", e, step, reduce_op=None
            )

    def add_speed_info(self, tgs, e2e_tgs, step):
        """Log per-step and end-to-end throughput (tokens per GPU per second)."""
        self.add_scalar("speed/tgs", tgs, step, reduce_op=None)
        self.add_scalar("speed/e2e_tgs", e2e_tgs, step, reduce_op=None)


def log_format(rank, debug=False):
    """Return the log-line format string for this rank.

    The prefix tags each line with the global rank and the local ranks in the
    DP / SP / TP meshes, so lines from different processes can be told apart.

    Args:
        rank: global process rank shown in the prefix.
        debug: when True, also include ``module:function:line`` of the caller.

    Returns:
        str: a format template with ``{time}``, ``{level}``, ``{message}``
        (and, in debug mode, ``{name}``, ``{function}``, ``{line}``) fields.
    """
    sp_rank = get_sp_mesh().get_local_rank()
    dp_rank = get_dp_mesh().get_local_rank()
    tp_rank = get_tp_mesh().get_local_rank()

    parts = [
        f'[XTuner][RANK {rank}][DP {dp_rank}][SP {sp_rank}][TP {tp_rank}]',
        '[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]',
    ]
    if debug:
        parts.append(
            '[<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>]')
    parts.append(' <level>{message}</level>')
    return ''.join(parts)


def parse_args():
    """Build this script's command-line interface and parse ``sys.argv``.

    Optimizer, schedule and data settings are not flags here: ``sft`` reads
    them from the InternEvo config passed with ``--train-cfg``.

    Returns:
        argparse.Namespace: one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Train LLM')

    # Model / parallelism options are listed under their own `--help` heading.
    model_group = parser.add_argument_group('model', 'Model Related Settings')
    add_model_opt = model_group.add_argument
    add_model_opt('--llm', help='repo id or local path of the model')
    add_model_opt('--train-cfg', help='interntrain config file')
    add_model_opt(
        '--dtype',
        default='auto',
        choices=['fp16', 'bf16', 'auto'],
        help="the dtype of the model forward. When set to 'auto', it will "
             'automatically determine whether bf16 is available, '
             'prioritizing the use of bf16.',
    )
    add_model_opt(
        '--selective-recompute',
        default=1.0,
        type=float,
        help='the ratio of re-computation for transforemer layers. '
             'The maximum is 1; the larger the value, the less memory '
             'required for training. The default is 1, meaning all layers '
             'need to be re-computated.',
    )
    add_model_opt(
        '--shard-strategy',
        default='full',
        choices=['full', 'hybrid'],
        help='The sharding strategy to be used for distributed training.',
    )
    add_model_opt('--cpu-offload', action='store_true', help='')
    add_model_opt('--sp-size', type=int, default=1, help='')
    add_model_opt(
        '--max-grad-norm', default=1, type=float, help='gradient clipping')

    # Run, checkpoint and logging options.
    add_opt = parser.add_argument
    add_opt(
        '--work-dir',
        default='work_dirs',
        help='the dir to save logs and checkpoints')
    # Both save intervals share the same semantics and help text.
    for flag, target in (('--checkpoint-interval', 'a checkpoint'),
                         ('--hf-interval', 'a hf model')):
        add_opt(
            flag,
            default=-1,
            type=float,
            help=f'how many steps to save {target}; it can be a floating '
                 'point number less than 1, or an integer greater than or equal '
                 "to 1. When it's a floating point, it will be multiplied by the "
                 'total number of training steps.',
        )
    add_opt(
        '--max-keep-ckpts',
        type=int,
        default=-1,
        help='the maximum number of checkpoints to keep.')
    add_opt(
        '--checkpoint-drop-optimizer',
        action='store_true',
        help='only model parameters are saved when saving a checkpoint. '
             'This can significantly reduce the size of checkpoint files, '
             'but the saved checkpoints cannot be resumed.',
    )
    add_opt('--log-interval', default=1, type=int, help='log interval')
    add_opt(
        '--resume', action='store_true', help='resume from the last checkpoint')
    add_opt(
        '--resume-from',
        type=str,
        default=None,
        help='specify checkpoint path to be resumed from.')
    add_opt('--seed', type=int, default=0, help='random seed for the training')
    add_opt('--debug', action='store_true', help='Set logger level to `DEBUG`')
    add_opt('--port', type=int, default=8888, help='port')
    return parser.parse_args()


def is_interval(step, total_steps, interval):
    """Return True when a periodic action (log / save) should run at ``step``.

    The action fires on every multiple of ``interval`` and always on the final
    step. A non-positive ``interval`` means "only on the final step"; it used
    to raise ``ZeroDivisionError`` for ``interval == 0``.

    Args:
        step: current step number.
        total_steps: last step of training.
        interval: period in steps.

    Returns:
        bool
    """
    if step == total_steps:
        return True
    if interval <= 0:
        return False
    return step % interval == 0


def map_meta_modules(model, meta_model):
    """Pair each submodule of ``meta_model`` with the same-named one in ``model``.

    Args:
        model: the materialized module tree.
        meta_model: a structurally identical module tree (e.g. on meta device).

    Returns:
        dict mapping each ``meta_model`` submodule (including the root) to the
        ``model`` submodule registered under the same qualified name.

    Raises:
        KeyError: if ``meta_model`` has a submodule name missing from ``model``.
    """
    by_name = dict(model.named_modules())
    return {
        meta_module: by_name[qualified_name]
        for qualified_name, meta_module in meta_model.named_modules()
    }


def build_llm_model(args, config, world_size, dtype=torch.float32):
    """Instantiate the causal LM that ``sft`` trains and exports.

    NOTE(review): ``config`` and ``world_size`` are currently ignored. The
    config is re-read from ``args.llm`` and overridden with a small hard-coded
    architecture that the original marks as test-only. The model is then built
    from that config with flash-attention 2 (``from_config``), so no pretrained
    weights are loaded. Confirm this is intended before real training runs.

    Args:
        args: CLI namespace; only ``args.llm`` is read.
        config: unused (see note above).
        world_size: unused.
        dtype: dtype recorded in the config and applied to the built model.

    Returns:
        The model after ``.to(dtype)``.
    """
    # TODO: test-only architecture override; switch back to `config` /
    # `AutoModelForCausalLM.from_pretrained(args.llm, config=config, ...)`
    # for real runs.
    test_overrides = dict(
        hidden_size=512,
        intermediate_size=1024,
        num_attention_heads=8,
        num_hidden_layers=6,
        num_key_value_heads=2,
        tie_word_embeddings=False,
        vocab_size=128512,
    )
    with LoadWoInit():
        model_cfg = AutoConfig.from_pretrained(
            args.llm, trust_remote_code=True, **test_overrides)
        model_cfg.use_cache = False
        model_cfg.torch_dtype = dtype
        model = AutoModelForCausalLM.from_config(
            config=model_cfg,
            trust_remote_code=True,
            attn_implementation='flash_attention_2',
        )

    # Cast to `dtype` here; `sft` separately promotes the trainable parameters
    # of its meta-device copy to fp32 for the optimizer.
    model.to(dtype)
    return model


def _resolve_step_interval(interval, total_steps):
    """Convert a CLI save interval into a whole number of steps.

    Non-positive values (including the ``-1`` default) mean "only at the end of
    training" and map to ``total_steps``. Fractions in (0, 1) are multiplied by
    ``total_steps`` and truncated, but never below 1 (a 0 interval would make
    ``is_interval`` divide by zero). Values >= 1 are truncated to ``int``.
    """
    if interval <= 0:
        return total_steps
    if interval < 1:
        return max(1, int(total_steps * interval))
    return int(interval)


def _prune_checkpoints(saved_dirs, max_keep):
    """Delete the oldest directories of ``saved_dirs`` beyond ``max_keep``.

    ``saved_dirs`` is ordered oldest-first and is mutated in place.
    """
    import shutil

    while len(saved_dirs) > max_keep:
        # shutil instead of `os.system('rm -rf ...')`: no shell, safe with
        # spaces or shell metacharacters in the path.
        shutil.rmtree(saved_dirs.pop(0), ignore_errors=True)


# @logger.catch
def sft(args):
    """Run the whole training job described by ``args`` and ``args.train_cfg``.

    Stages: (1) distributed environment and logging, (2) InternEvo dataloader,
    (3) megatron/FSDP parallelization of the HF model, (4) AdamW with a linear
    warmup followed by cosine annealing (optionally resumed from a DCP
    checkpoint), (5) the training loop with TensorBoard logging, HF exports and
    resumable DCP checkpoints.

    Optimizer hyper-parameters, ``total_steps`` and the gradient-accumulation
    count are read from the InternEvo config (``gpc.config``) and written onto
    ``args``.

    Args:
        args: namespace returned by ``parse_args``.
    """
    ###########################################################################
    #                           1. Environment                                #
    ###########################################################################
    with etc.collect_execute_time("init_comm_time"):
        catch_error_node(initialize_distributed_env)(
            config=args.train_cfg,
            launcher='torch',
            master_port=args.port,
            seed=args.seed,
            old_config=True
        )
    assert hasattr(gpc, "config") and gpc.config is not None

    data_rank = gpc.get_local_rank(ParallelMode.DATA)
    data_world_size = gpc.get_world_size(ParallelMode.DATA)

    setup_parallel(sp_size=args.sp_size, tp_size=1)
    set_random_seed(args.seed)

    dp_mesh = get_dp_mesh()
    tp_mesh = get_tp_mesh()
    sp_mesh = get_sp_mesh()
    fsdp_mesh = get_fsdp_mesh()  # dp_size * sp_size
    world_mesh = get_world_mesh()  # dp_size * sp_size * tp_size

    dp_size = dp_mesh.size()
    tp_size = tp_mesh.size()
    sp_size = sp_mesh.size()
    world_size = world_mesh.size()

    rank = dist.get_rank()
    # All ranks must agree on the run directory, so rank 0's timestamp wins.
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    objects = [timestamp]
    dist.broadcast_object_list(objects, src=0)
    timestamp = objects[0]

    args.work_dir = os.path.join(args.work_dir, timestamp)
    mkdir_or_exist(args.work_dir)

    log_file = os.path.join(args.work_dir, f'rank{rank}.log')
    dataset_types, dataset_subset_types = get_lang_subset_types(
        gpc.config.data.train_folder
    )
    tbwriter = SummaryWriterWrapper(log_dir=args.work_dir, dataset_types=dataset_types)

    # Change the log format printed in the terminal
    lvl = 'DEBUG' if args.debug else 'INFO'
    logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
    # Change the format saved in the log file
    logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)

    if rank == 0:
        env = collect_env()
        import transformers

        import xtuner
        env['Transformers'] = transformers.__version__
        env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
        runtime_env = OrderedDict()
        runtime_env.update(env)
        runtime_env['Seed'] = args.seed
        runtime_env['World Size'] = world_size
        runtime_env['DP Size'] = dp_size
        runtime_env['SP Size'] = sp_size
        runtime_env['TP Size'] = tp_size

        runtime_env_info = '\n    ' + '\n    '.join(
            f'{k}: {v}' for k, v in runtime_env.items())
        dash_line = '-' * 60
        logger.info('\n' + dash_line + '\nRuntime environment:' +
                    runtime_env_info + '\n' + dash_line + '\n')
    # -------------------    Environment  End  ------------------------------ #
    if args.resume_from and args.resume is False:
        args.resume = True
    if args.resume is True and args.resume_from is None:
        # find last checkpoint (written next to the timestamped run dirs)
        save_file = os.path.join(args.work_dir, '../last_checkpoint')
        if os.path.exists(save_file):
            with open(save_file) as f:
                args.resume_from = f.read().strip()
        else:
            logger.warning('Did not find last_checkpoint to be resumed. training from scratch.')
            args.resume = False
    if args.resume:
        assert not args.checkpoint_drop_optimizer, '`resume` and `checkpoint_drop_optimizer` cannot be set at the same time.'

    ###########################################################################
    #                     replace config                                     #
    ###########################################################################
    args.wd = gpc.config.adam.weight_decay
    args.lr = gpc.config.adam.lr
    args.adam_beta1 = gpc.config.adam.adam_beta1
    args.adam_beta2 = gpc.config.adam.adam_beta2
    args.adam_epsilon = gpc.config.adam.adam_eps
    args.total_steps = gpc.config.data.total_steps
    args.iters_per_step = gpc.config.data.gradient_accumulation
    args.warmup_ratio = gpc.config.lr_scheduler.warmup_steps
    args.lr_min = gpc.config.MIN_LEARNING_RATE
    logger.info(args)
    logger.info(f"data_rank: {data_rank}, data_world_size: {data_world_size}")
    ###########################################################################
    #                     2. Dataset & Dataloader                             #
    ###########################################################################

    start_load_data_t = time.time()

    assert varlen_attn_is_available()

    with etc.collect_execute_time("load_data_time"):
        train_dl = build_train_loader_with_data_type(
            data_cfg=gpc.config.data,
            data_rank=data_rank,
            data_world_size=data_world_size,
        )
    train_state = get_train_state(train_dl)

    load_data_cost_time = time.time() - start_load_data_t
    logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')
    # -------------------    Dataset & Dataloader  End  --------------------- #

    ###########################################################################
    #                          3. FSDP                                        #
    ###########################################################################
    if args.dtype == 'auto':
        args.dtype = 'bf16' if DEVICE_MODULE.is_bf16_supported() else 'fp16'

    # NOTE(review): `autocast` and `scaler` are created below but never used by
    # the training loop, so fp16 training currently runs without loss scaling.
    if args.dtype == 'fp16':
        dtype = torch.float16
        autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype)
        scaler = ShardedGradScaler()
    elif args.dtype == 'bf16':
        if DEVICE_MODULE.is_bf16_supported():
            dtype = torch.bfloat16
            autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype)
            scaler = None
        else:
            raise RuntimeError('The device does not support `bf16`, '
                               'please set `dtype` to `fp16`.')
    else:
        raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
                           f'but found {args.dtype}.')

    llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True)
    if is_flash_attn_2_available():
        llm_cfg.attn_implementation = 'flash_attention_2'

    llm_cfg.use_cache = False
    llm_cfg.torch_dtype = dtype

    # Only load parameters on rank 0 to avoid each rank repeatedly loading the
    # same model into the CPU, wasting memory
    if rank == 0:
        with torch.device('cpu'):
            rank0_llm = build_llm_model(args, llm_cfg, world_size, dtype)
    else:
        rank0_llm = None

    with torch.device('meta'):
        # Ensure all numerical values in the optimizer are fp32.
        # FSDP will use low precision during forward.
        llm = build_llm_model(args, llm_cfg, world_size, dtype)
        dispatch_hf_code(llm)
        for module in llm.modules():
            for p_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    param_fp32 = torch.nn.Parameter(
                        param.to(dtype=torch.float32))
                    setattr(module, p_name, param_fp32)

    mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype)

    with profile_time_and_memory('[Parallelize LLM]'):
        megatron_parallelize(
            llm,
            rank0_llm,
            dp_mesh=fsdp_mesh,
            tp_mesh=tp_mesh,
            mp_policy=mp_policy,
            recompute_ratio=args.selective_recompute,
            reshard_after_forward=True)

        llm.train()

    if rank == 0:
        logger.info(llm)

    dist.barrier()
    # --------------------------    FSDP  End  ------------------------------ #

    ###########################################################################
    #                      4. Optimizer & Scheduler                           #
    ###########################################################################
    required_grad_params = [
        param for param in llm.parameters() if param.requires_grad
    ]
    optimizer = AdamW(
        required_grad_params,
        lr=args.lr,
        weight_decay=args.wd,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon)

    # `iter` means once forward+backward
    # `step` means once optimizer step
    # `iters_per_step` means gradient accumulative counts
    iters_per_step = args.iters_per_step
    total_steps = args.total_steps

    checkpoint_interval = _resolve_step_interval(args.checkpoint_interval, total_steps)
    hf_interval = _resolve_step_interval(args.hf_interval, total_steps)

    max_keep_ckpts = args.max_keep_ckpts
    if max_keep_ckpts <= 0:
        # save all checkpoints
        max_keep_ckpts = total_steps + 100000
    save_hf_ckpt_names = []
    save_pt_ckpt_names = []

    if args.warmup_ratio < 1:
        warmup_steps = int(args.warmup_ratio * total_steps)
    else:
        warmup_steps = int(args.warmup_ratio)

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(optimizer, warmup_fn)

    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)

    # ----------------    Optimizer & Scheduler End   ----------------------- #
    if args.resume:
        logger.info(f'[Resume] Resume from {args.resume_from}')
        _options = StateDictOptions(
            cpu_offload=True, ignore_frozen_params=True)
        (shard_model_state_dict,
         shard_optimizer_state_dict) = get_state_dict(
            llm, optimizer, options=_options)
        state_dict = {
            'model': shard_model_state_dict,
            'optimizer': shard_optimizer_state_dict,
            'train_state': train_state,
            'warmup_scheduler': warmup_scheduler,
            'cosine_scheduler': cosine_scheduler
        }
        # inplace state_dict
        dcp.load(
            state_dict=state_dict,
            checkpoint_id=args.resume_from,
        )

        _options = StateDictOptions(
            cpu_offload=True, strict=False)
        set_state_dict(
            llm,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optimizer"],
            options=_options
        )

    if train_state.batch_count >= total_steps:
        logger.info("Training has finished, exiting...")
        return

    gpc.train_state = train_state

    ###########################################################################
    #                          5. Training                                    #
    ###########################################################################
    start_train_t = time.time()
    total_consumed_tokens = 0
    DEVICE_MODULE.empty_cache()
    DEVICE_MODULE.reset_peak_memory_stats()
    max_memory = DEVICE_MODULE.max_memory_allocated()
    logger.info('[Train] Begin Train Loop. The current GPU memory is '
                f'{(max_memory / 1024 ** 3):.1f}GB')

    # Zero-padded step numbers keep checkpoint directory names sortable.
    num_digits = len(str(abs(total_steps)))
    train_iter = iter(train_dl)

    for batch_count in itertools.count(train_state.batch_count):
        if train_state.step_count >= total_steps:
            break

        if train_state.step_count <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_last_lr()[0]

        DEVICE_MODULE.reset_peak_memory_stats()

        step_loss = 0
        step_start_t = time.time()
        step_consumed_tokens = 0

        _data_start_t = time.time()

        step_data_list = []
        rank_grad_tokens = 0

        # the first dim is grad acc step
        batch, train_iter = load_new_batch_with_train_state(
            train_dl=train_dl, train_iter=train_iter, train_state=train_state)

        inputs, labels = batch
        input_ids = inputs['input_ids']
        cu_seqlens = inputs['cu_seqlens']
        assert input_ids.shape[0] == iters_per_step

        gpc.config.batch_count = batch_count
        train_state.batch_count = batch_count
        train_state.num_consumed_samples_in_epoch += len(batch[1])

        for _iter in range(iters_per_step):
            input_ids_iter = input_ids[_iter: _iter + 1]
            labels_iter = labels[_iter: _iter + 1]
            cu_seqlens_iter = cu_seqlens[_iter]
            num_token = cu_seqlens_iter[1:] - cu_seqlens_iter[:-1]
            # The last position is dropped below (inputs/labels trimmed by 1),
            # so shrink the last packed sequence to match.
            num_token[-1] = num_token[-1] - 1
            if num_token[-1] == 0:
                num_token = num_token[:-1]

            # labels already offset
            labels_iter = labels_iter[:, :-1]
            input_ids_iter = input_ids_iter[:, :-1]

            rank_grad_tokens += (labels_iter >= 0).sum()
            step_data_list.append({"input_ids": input_ids_iter,
                                   "labels": labels_iter,
                                   "num_tokens": num_token})

        rank_grad_tokens = rank_grad_tokens.to(DEVICE)
        dist.all_reduce(rank_grad_tokens)
        # Every tp/sp replica counted the same labels; undo the duplication.
        global_grad_tokens = rank_grad_tokens / tp_size / sp_size

        step_data_time = time.time() - _data_start_t
        unreduced_losses = []
        for data in step_data_list:
            input_ids = data['input_ids'].to(DEVICE)
            labels = data['labels'].to(DEVICE)
            num_tokens = data['num_tokens'].to(DEVICE)

            if sp_size > 1:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                input_ids = pad_for_sequence_parallel(input_ids, 0, sp_mesh, dim=1)
                _num_pad = input_ids.numel() - num_tokens.sum()
                if _num_pad > 0:
                    _num_pad = torch.IntTensor([_num_pad]).to(DEVICE)
                    num_tokens = torch.cat([num_tokens, _num_pad], dim=-1)

                input_ids = split_for_sequence_parallel(
                    input_ids, dim=1, sp_mesh=sp_mesh)

                labels = pad_for_sequence_parallel(labels, -100, sp_mesh, dim=1)
                labels = split_for_sequence_parallel(
                    labels, dim=1, sp_mesh=sp_mesh)

            packed_ctx = packed_sequence(num_tokens, sp_mesh=sp_mesh)

            with packed_ctx:

                logits = llm(input_ids=input_ids, use_cache=False).logits

                loss = F.cross_entropy(logits.squeeze(), labels.squeeze(), reduction='none')  # 1, seqlen
                # NOTE(review): with sp_size > 1 this is only this rank's
                # padded shard, so it does not align with the batch's type_ids.
                unreduced_losses.append(loss.detach().clone())
                if sp_size > 1:
                    sp_group = sp_mesh.get_group()
                    sp_pt_loss = dist.nn.functional.all_gather(loss, sp_group)
                    sp_pt_labels = dist.nn.functional.all_gather(labels, sp_group)

                    loss = torch.cat(sp_pt_loss, dim=-1)
                    labels = torch.cat(sp_pt_labels, dim=-1)

                loss = loss.sum() / global_grad_tokens * dp_size

                loss.backward()

            step_loss += loss.item()

            step_consumed_tokens += num_tokens.sum() / sp_size / tp_size

        # One optimizer step per batch. `step_count` previously advanced once
        # per micro-batch, which ended training after total_steps /
        # iters_per_step real steps and skipped interval boundaries.
        train_state.step_count += 1

        grad_norm = clip_grad_norm_(
            required_grad_params, fsdp_mesh, args.max_grad_norm)
        if torch.isfinite(grad_norm):
            optimizer.step()
        else:
            # Skip the update so NaN/Inf gradients never reach the weights.
            train_state.inf_nan_skip_batches += 1
        optimizer.zero_grad()

        step_time = time.time() - step_start_t
        eta = step_time * (total_steps - train_state.step_count)
        eta = timedelta(seconds=int(eta))
        tgs = int(step_consumed_tokens / step_time)
        max_memory = DEVICE_MODULE.max_memory_allocated()
        total_consumed_tokens += step_consumed_tokens
        end2end_tgs = int(total_consumed_tokens / (time.time() - start_train_t))

        # log to tensorboard
        tensorboard_start_time = time.time()
        tbwriter.add_data_infos(batch, train_state, batch_count)
        tbwriter.add_train_dynamics(step_loss, unreduced_losses, batch, batch_count)
        tbwriter.add_optimize_info(grad_norm.detach().clone(), train_state, batch_count)
        tbwriter.add_speed_info(tgs, end2end_tgs, batch_count)
        tensorboard_time = time.time() - tensorboard_start_time

        if is_interval(train_state.step_count, total_steps, args.log_interval):
            logger.info(f'[Train] (Epoch 1) Step '
                        f'{train_state.step_count}/{total_steps}  '
                        f'lr: {cur_lr:.6f}  loss: {step_loss:.3f}  '
                        f'grad_norm: {grad_norm:.2f}  '
                        f'max_memory: {(max_memory / 1024 ** 3):.1f}GB  '
                        f'text_tokens: {step_consumed_tokens}  '
                        f'tgs: {tgs} e2e_tgs: {end2end_tgs} data_time: {step_data_time:.2f}s  '
                        f'time: {step_time:.2f}s tb_time: {tensorboard_time:.2f} '
                        f'eta: {eta}')

        if is_interval(train_state.step_count, total_steps, hf_interval):
            DEVICE_MODULE.empty_cache()
            hf_dir = os.path.join(args.work_dir, f'hf-{train_state.step_count:0{num_digits}}')

            with profile_time_and_memory('[HF Checkpoint]'):

                from torch.distributed._tensor import DTensor

                if rank == 0:
                    llm_state_dict = {}

                # Gathering a DTensor is a collective: every rank walks the
                # full state dict, but only rank 0 keeps the gathered values.
                for name, param in llm.state_dict().items():
                    if isinstance(param, DTensor):
                        with torch.no_grad():
                            full_param = param.full_tensor().cpu()
                    else:
                        full_param = param.cpu()

                    if rank == 0:
                        llm_state_dict[name] = full_param

                if rank == 0:
                    rank0_llm.load_state_dict(llm_state_dict)
                    rank0_llm.save_pretrained(hf_dir)

                dist.barrier()

            if rank == 0:
                save_hf_ckpt_names.append(hf_dir)
                _prune_checkpoints(save_hf_ckpt_names, max_keep_ckpts)

            max_memory = DEVICE_MODULE.max_memory_allocated()
            logger.info('[HF Checkpoint] During saving HF checkpoint, the peak GPU '
                        f'memory is {max_memory / 1024 ** 3:.1f}GB.')

        if is_interval(train_state.step_count, total_steps, checkpoint_interval):
            if args.checkpoint_drop_optimizer:
                logger.warning('The saved checkpoint cannot be resumed. '
                               'If you want to save a resumable checkpoint, '
                               'please remove `--checkpoint-drop-optimizer` '
                               'from the command.')
            else:
                with profile_time_and_memory('[PT Checkpoint]'):
                    # FSDP cannot be saved via torch.save
                    # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html  # noqa: E501
                    _options = StateDictOptions(
                        cpu_offload=True, ignore_frozen_params=True)
                    (shard_model_state_dict,
                     shard_optimizer_state_dict) = get_state_dict(
                        llm, optimizer, options=_options)

                    state_dict = {
                        'model': shard_model_state_dict,
                        'optimizer': shard_optimizer_state_dict,
                        'warmup_scheduler': warmup_scheduler.state_dict(),
                        'cosine_scheduler': cosine_scheduler.state_dict(),
                        'train_state': train_state.state_dict(),
                    }
                    ckpt_id = f'{train_state.step_count:0{num_digits}}-of-{total_steps:0{num_digits}}'
                    ckpt_dir = os.path.join(args.work_dir, f'ckpt-{ckpt_id}')
                    dcp.save(state_dict, checkpoint_id=ckpt_dir)

                if rank == 0:
                    # Lets `--resume` without `--resume-from` find this ckpt.
                    save_file = os.path.join(args.work_dir, '../', 'last_checkpoint')
                    with open(save_file, 'w') as f:
                        f.write(ckpt_dir)

                    save_pt_ckpt_names.append(ckpt_dir)
                    _prune_checkpoints(save_pt_ckpt_names, max_keep_ckpts)

            max_memory = DEVICE_MODULE.max_memory_allocated()
            logger.info('[Checkpoint] During saving checkpoint, the peak GPU '
                        f'memory is {max_memory / 1024 ** 3:.1f}GB.')

    train_cost_time = time.time() - start_train_t
    logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}')
    # ------------------------    Training  End  ---------------------------- #
    # ------------------------    Training  End  ---------------------------- #


if __name__ == '__main__':
    # Script entry point: parse the CLI and run training.
    sft(parse_args())
